package cn.easyutil.task.delay;

import cn.easyutil.task.delay.beans.TaskDefinition;
import cn.easyutil.task.delay.handler.DelayTaskRetryErrorHandler;
import cn.hutool.core.collection.ListUtil;
import cn.hutool.core.date.DateUtil;
import cn.hutool.core.exceptions.ExceptionUtil;
import cn.hutool.json.JSONUtil;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import java.util.*;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;

/**
 * Delay-task queue backed by Redis.
 *
 * <p>Redis structures used (each key is suffixed with the simple class name of the
 * concrete subclass, so different providers do not share data):
 * <ul>
 *   <li>a hash holding the JSON definition of every known task, keyed by task value;</li>
 *   <li>a sorted set holding pending task values scored by execute time (epoch millis);</li>
 *   <li>a list holding task values that are due and waiting for a worker;</li>
 *   <li>a set recording task values that were popped from the list but not yet finished,
 *       which {@link #recover()} re-queues on start;</li>
 *   <li>a list receiving tasks that still fail after all retries;</li>
 *   <li>one short-lived string key per task acting as a distributed execution lock.</li>
 * </ul>
 *
 * <p>{@link #start()} submits two long-running loops (scanner and consumer) to the
 * configured thread pool; {@link #stop()} asks them to finish and shuts the pool down.
 */
public abstract class AbstractRedisDelayTask implements DelayTask{

    // Key prefixes. The effective keys are per-instance (prefix + provider name), see constructor.
    private static final String TRANSFER_RECORDS_KEY_PREFIX = "delay_task_transfer_record_";
    private static final String ALL_VALUE_KEY_PREFIX = "delay_task_all_value_";
    private static final String DELAY_KEY_PREFIX = "delay_task_delay_key_";
    private static final String QUEUE_KEY_PREFIX = "delay_task_queue_key_";
    private static final String ERROR_TASK_KEY_PREFIX = "delay_task_error_key_";
    private static final String LOCK_KEY_PREFIX = "delay_task_lock_key_";

    // Upper bound for one scanner sleep, so newly due tasks are noticed within this time.
    private static final long MAX_SCAN_WAIT_MILLIS = 3000L;
    // Sleep of the consumer loop when the pool is busy or the ready list is empty.
    private static final long CONSUME_IDLE_WAIT_MILLIS = 1000L;
    // Time-to-live of the per-task execution lock.
    private static final long TASK_LOCK_SECONDS = 5L;
    // The consumer stops dispatching while active threads exceed this share of the maximum pool size.
    private static final double BUSY_RATIO = 0.9;
    // Threads permanently occupied by the scanner loop and the consumer loop.
    private static final int RESERVED_THREADS = 2;

    // KEYS: all-task hash, delay zset, ready list. ARGV: task value, task JSON, delay flag, execute time.
    private static final DefaultRedisScript<Object> CREATE_TASK_SCRIPT = newScript(Object.class, "" +
            " redis.call('HSET',KEYS[1],ARGV[1],ARGV[2]); " +
            " if (tonumber(ARGV[3]) > 0) " +
            " then " +
            " redis.call('ZADD',KEYS[2],ARGV[4],ARGV[1]); " +
            " else " +
            " redis.call('LPUSH',KEYS[3],ARGV[1]); " +
            " end ");

    // KEYS: delay zset, ready list. ARGV: current time. Moves at most one due task per call.
    private static final DefaultRedisScript<Long> TRANSFER_SCRIPT = newScript(Long.class, "" +
            " local zList; " +
            " zList = redis.call('ZRANGE',KEYS[1],0,0,'WITHSCORES'); " +
            " if (#zList == 0) " +
            " then return -1 " +
            " else " +
            " local value = zList[1]; " +
            " local executeTime = zList[2]; " +
            " if (tonumber(executeTime) > tonumber(ARGV[1])) " +
            " then return tonumber(executeTime) - tonumber(ARGV[1]); end;" +
            " redis.call('LPUSH',KEYS[2],value); " +
            " redis.call('ZREM',KEYS[1],value);" +
            " return 0; " +
            " end " +
            "");

    // KEYS: ready list, transfer-record set. Producers LPUSH, so RPOP yields FIFO order.
    private static final DefaultRedisScript<String> CONSUME_ONE_SCRIPT = newScript(String.class, "" +
            " local task = redis.call('RPOP',KEYS[1]); " +
            " if (task ~= nil and type(task) == 'string') " +
            " then " +
            " redis.call('SADD',KEYS[2],task) " +
            " return task; " +
            " end " +
            " return nil; " +
            "");

    // KEYS: all-task hash, transfer-record set. ARGV: task value.
    private static final DefaultRedisScript<Object> CLEAR_TASK_SCRIPT = newScript(Object.class, "" +
            " redis.call('HDEL',KEYS[1],ARGV[1]); " +
            " redis.call('SREM',KEYS[2],ARGV[1]); " +
            " ");

    // KEYS: lock key. ARGV: owner token. Deletes the lock only if this owner still holds it.
    private static final DefaultRedisScript<Long> RELEASE_LOCK_SCRIPT = newScript(Long.class, "" +
            " if (redis.call('GET',KEYS[1]) == ARGV[1]) " +
            " then return redis.call('DEL',KEYS[1]); " +
            " else return 0; " +
            " end ");

    // Monitor the scanner sleeps on; notified when a task is created or the queue is stopped.
    private final Object lock = new Object();

    // Monitor the consumer sleeps on; notified when the queue is stopped.
    private final Object waitLock = new Object();

    // Set of task values handed to a consumer but not finished yet.
    private final String transferRecordsKeyOfSet;
    // Hash of all task definitions.
    private final String allValueKeyOfHash;
    // Sorted set of pending tasks, scored by execute time.
    private final String delayKeyOfZSet;
    // List of tasks ready for execution.
    private final String queueKeyOfList;
    // List of tasks that failed after all retries.
    private final String errorTaskKeyOfList;
    // Prefix of the per-task distributed lock key.
    private final String lockKeyOfStr;

    private final StringRedisTemplate template;

    // Maximum number of retries after a failed execution.
    private int retryCount;

    private final Collection<DelayTaskRetryErrorHandler> handlers;

    // Pool running the scanner loop, the consumer loop and the task workers.
    private ThreadPoolExecutor threadPool;

    // True between start() and stop(); the background loops exit once it turns false.
    private volatile boolean running;

    /**
     * Creates the queue and registers the default final-failure handler, which stores the
     * failed task as JSON in the error list.
     *
     * @param template Redis access, must not be null
     * @throws NullPointerException if {@code template} is null
     */
    public AbstractRedisDelayTask(StringRedisTemplate template) {
        if(template == null){
            throw new NullPointerException("stringRedisTemplate 不能为空");
        }
        String providerName = this.getClass().getSimpleName();
        this.template = template;
        // Keys are instance state: creating further instances must not alter the keys of this one.
        this.transferRecordsKeyOfSet = TRANSFER_RECORDS_KEY_PREFIX + providerName;
        this.allValueKeyOfHash = ALL_VALUE_KEY_PREFIX + providerName;
        this.delayKeyOfZSet = DELAY_KEY_PREFIX + providerName;
        this.queueKeyOfList = QUEUE_KEY_PREFIX + providerName;
        this.errorTaskKeyOfList = ERROR_TASK_KEY_PREFIX + providerName;
        this.lockKeyOfStr = LOCK_KEY_PREFIX + providerName;
        final String errorKey = this.errorTaskKeyOfList;
        // Default handler: accepts every task and parks it in the Redis error list.
        this.handlers = new ArrayList<>();
        this.handlers.add(new DelayTaskRetryErrorHandler() {
            @Override
            public boolean supports(TaskDefinition definition) {
                return true;
            }
            @Override
            public void process(TaskDefinition definition) {
                template.opsForList().leftPush(errorKey, JSONUtil.toJsonStr(definition));
            }
        });
    }

    /** Builds a reusable script object from its Lua text and expected result type. */
    private static <T> DefaultRedisScript<T> newScript(Class<T> resultType, String text) {
        DefaultRedisScript<T> script = new DefaultRedisScript<>();
        script.setResultType(resultType);
        script.setScriptText(text);
        return script;
    }

    /**
     * Creates and installs a pool with a 30 second keep-alive and a bounded queue of 1000 entries.
     *
     * @param coreSize    core pool size
     * @param activeCount maximum pool size
     */
    public void setThreadPool(int coreSize, int activeCount) {
        setThreadPool(new ThreadPoolExecutor(coreSize, activeCount, 30, TimeUnit.SECONDS, new LinkedBlockingQueue<>(1000)));
    }

    /**
     * Installs the pool. Core and maximum sizes of 2 or less are raised by 2, because the
     * scanner loop and the consumer loop each occupy one thread for as long as the queue runs.
     *
     * @param threadPool pool to use, must not be null
     * @throws NullPointerException if {@code threadPool} is null
     */
    public void setThreadPool(ThreadPoolExecutor threadPool) {
        if (threadPool == null) {
            throw new NullPointerException("任务处理线程池不能为空");
        }
        int corePoolSize = threadPool.getCorePoolSize();
        int maximumPoolSize = threadPool.getMaximumPoolSize();
        int newCore = corePoolSize <= RESERVED_THREADS ? corePoolSize + RESERVED_THREADS : corePoolSize;
        int newMax = maximumPoolSize <= RESERVED_THREADS ? maximumPoolSize + RESERVED_THREADS : maximumPoolSize;
        // The maximum may never end up below the core size.
        newMax = Math.max(newMax, newCore);
        // Raise the maximum first: setCorePoolSize rejects a core size above the current maximum.
        if (newMax != maximumPoolSize) {
            threadPool.setMaximumPoolSize(newMax);
        }
        if (newCore != corePoolSize) {
            threadPool.setCorePoolSize(newCore);
        }
        this.threadPool = threadPool;
    }

    /**
     * Sets the maximum number of retries; the sign of the argument is ignored.
     *
     * @param retryCount retry limit
     */
    public void setRetryCount(int retryCount){
        // Math.abs(Integer.MIN_VALUE) is still negative, so map it explicitly.
        this.retryCount = retryCount == Integer.MIN_VALUE ? Integer.MAX_VALUE : Math.abs(retryCount);
    }

    /**
     * Registers an additional handler for tasks that failed after all retries.
     *
     * @param handler handler to add; null is ignored
     */
    public void addHandler(DelayTaskRetryErrorHandler handler){
        if(handler != null){
            this.handlers.add(handler);
        }
    }


    @Override
    public boolean createTask(Long executeTime, String taskValue) {
        TaskDefinition task = TaskDefinition.create(executeTime, taskValue);
        return createTask(task);
    }

    /**
     * Stores the task definition and schedules it.
     *
     * <p>The task is always written to the delay sorted set (the delay flag passed to the
     * script is constantly "1"); tasks whose execute time has already passed are moved to the
     * ready list by the scanner on its next pass. A missing execute time defaults to now.
     *
     * @param task task to schedule
     * @return always true
     * @throws NullPointerException if the task or its value is missing
     */
    @Override
    public boolean createTask(TaskDefinition task) {
        if(task == null){
            throw new NullPointerException("缺少任务信息");
        }
        if(StringUtils.isEmpty(task.getValue())){
            throw new NullPointerException("缺少任务信息");
        }
        if (task.getExecuteTime() == null) {
            task.setExecuteTime(System.currentTimeMillis());
        }
        task.setTaskClass(this.getClass().getCanonicalName());
        List<String> keys = new ArrayList<>();
        keys.add(allValueKeyOfHash);
        keys.add(delayKeyOfZSet);
        keys.add(queueKeyOfList);
        String args1 = task.getValue();
        String args2 = JSONUtil.toJsonStr(task);
        String args3 = "1";
        String args4 = task.getExecuteTime().toString();
        this.template.execute(CREATE_TASK_SCRIPT, keys, args1, args2, args3, args4);
        // Wake the scanner so a task due earlier than its current sleep is picked up promptly.
        synchronized (lock) {
            lock.notifyAll();
        }
        return true;
    }

    /**
     * Removes the given task values from the task hash, the delay sorted set and the
     * transfer-record set. Values already in the ready list are not removed from it; the
     * consumer discards them once it finds no definition in the hash.
     *
     * @param tasks task values; null or empty is a no-op
     * @return always true
     */
    @Override
    public boolean removeTask(String... tasks) {
        if (tasks == null || tasks.length == 0) {
            return true;
        }
        Object[] values = tasks;
        this.template.opsForHash().delete(allValueKeyOfHash, values);
        this.template.opsForZSet().remove(delayKeyOfZSet, values);
        this.template.opsForSet().remove(transferRecordsKeyOfSet, values);
        return true;
    }

    @Override
    public boolean removeTask(Collection<String> tasks) {
        if(CollectionUtils.isEmpty(tasks)){
            return true;
        }
        return removeTask(tasks.toArray(new String[0]));
    }

    /**
     * Default implementation returns false, which makes the consumer send every execution
     * exception through {@link #failProcess(TaskDefinition, Exception)}.
     */
    @Override
    public boolean supportsException(Exception e) {
        return false;
    }

    /**
     * Re-queues unfinished tasks and starts the scanner and consumer loops. Calling it again
     * while the queue is running has no effect.
     *
     * @throws NullPointerException if no thread pool has been configured
     */
    @Override
    public void start() {
        if(this.threadPool == null){
            throw new NullPointerException("任务处理线程池不能为空");
        }
        synchronized (lock) {
            if (running) {
                // Already started: a second set of loops would only compete with the first.
                return;
            }
            running = true;
        }
        try {
            recover();
            scan();
            consume();
        } catch (RuntimeException e) {
            // Start-up failed half-way: make any loop that was already submitted exit.
            running = false;
            throw e;
        }
    }

    /**
     * Asks the background loops to exit, wakes them if they are sleeping, and shuts the pool
     * down. Tasks already handed to workers are allowed to finish.
     */
    @Override
    public void stop() {
        running = false;
        synchronized (lock) {
            lock.notifyAll();
        }
        synchronized (waitLock) {
            waitLock.notifyAll();
        }
        if(this.threadPool != null){
            this.threadPool.shutdown();
        }
    }

    /**
     * Submits the scanner loop: repeatedly moves due tasks from the delay sorted set to the
     * ready list, sleeping between passes until the next task is due (at most
     * {@link #MAX_SCAN_WAIT_MILLIS}). The loop ends on stop or interrupt; any other exception
     * aborts it with a RuntimeException.
     */
    private void scan() {
        this.threadPool.execute(()->{
            while (running) {
                try {
                    long transfer = transferToQueueByZSet(delayKeyOfZSet, queueKeyOfList, System.currentTimeMillis());
                    if (transfer == 0) {
                        // One task was moved; there may be more that are due.
                        continue;
                    }
                    long waitTime = transfer > 0 ? Math.min(transfer, MAX_SCAN_WAIT_MILLIS) : MAX_SCAN_WAIT_MILLIS;
                    synchronized (lock) {
                        // Re-check under the monitor so a concurrent stop() is not missed.
                        if (running) {
                            lock.wait(waitTime);
                        }
                    }
                } catch (InterruptedException e) {
                    // Keep the interrupt status for the pool and leave the loop.
                    Thread.currentThread().interrupt();
                    return;
                } catch (Exception e) {
                    if (!running) {
                        // Failures caused by shutting down are expected; exit quietly.
                        return;
                    }
                    throw new RuntimeException("任务扫描异常中止", e);
                }
            }
        });
    }

    /**
     * Moves the earliest task from the delay sorted set to the ready list if it is due.
     *
     * @param zSetKey   key of the delay sorted set
     * @param queueKey  key of the ready list
     * @param timeStamp time (epoch millis) a task's execute time is compared against
     * @return 0 if one task was moved, -1 if there is nothing to move (or the script
     *         returned no result), otherwise the milliseconds until the earliest task is due
     */
    private long transferToQueueByZSet(String zSetKey, String queueKey, long timeStamp) {
        List<String> keys = new ArrayList<>();
        keys.add(zSetKey);
        keys.add(queueKey);
        Long result = this.template.execute(TRANSFER_SCRIPT, keys, String.valueOf(timeStamp));
        // A missing result is treated as "nothing to move" so the scanner sleeps instead of spinning.
        return result == null ? -1L : result;
    }

    /**
     * Re-queues tasks that were handed to a consumer but never finished (for example because
     * the process restarted or crashed). The push itself runs on the thread pool, in batches
     * of 500.
     */
    private void recover() {
        Set<String> members = this.template.opsForSet().members(transferRecordsKeyOfSet);
        if (CollectionUtils.isEmpty(members)) {
            return;
        }
        List<List<String>> partition = ListUtil.split(new ArrayList<>(members), 500);
        this.threadPool.execute(() -> partition.forEach(val -> {
            String[] array = val.toArray(new String[0]);
            this.template.opsForList().leftPushAll(queueKeyOfList,array);
        }));
    }

    /**
     * Submits the consumer loop: pops ready tasks and hands each one to a worker. The loop
     * ends on stop or interrupt; any other exception aborts it with a RuntimeException.
     */
    private void consume() {
        int maxPoolSize = this.threadPool.getMaximumPoolSize();
        this.threadPool.execute(() -> {
            while (running) {
                try {
                    // Back off while more than 90% of the maximum pool size is active.
                    if (maxPoolSize * BUSY_RATIO < this.threadPool.getActiveCount()) {
                        pauseConsumer();
                        continue;
                    }
                    String taskValue = consumeOne();
                    // Nothing ready: sleep for a while before polling again.
                    if (StringUtils.isEmpty(taskValue)) {
                        pauseConsumer();
                        continue;
                    }
                    Object taskInfo = this.template.opsForHash().get(allValueKeyOfHash, taskValue);
                    if (taskInfo == null || StringUtils.isEmpty(taskInfo.toString())) {
                        // No definition left (e.g. the task was removed): drop the consumption record.
                        taskOverWithClear(taskValue);
                        continue;
                    }
                    TaskDefinition task = JSONUtil.toBean(taskInfo.toString(), TaskDefinition.class);
                    this.threadPool.execute(() -> runTask(task));
                } catch (InterruptedException e) {
                    // Keep the interrupt status for the pool and leave the loop.
                    Thread.currentThread().interrupt();
                    return;
                } catch (Exception e) {
                    if (!running) {
                        // Failures caused by shutting down are expected; exit quietly.
                        return;
                    }
                    throw new RuntimeException("任务消费异常中止", e);
                }
            }
        });
    }

    /** Lets the consumer loop sleep for a while, unless the queue has been stopped. */
    private void pauseConsumer() throws InterruptedException {
        synchronized (waitLock) {
            if (running) {
                waitLock.wait(CONSUME_IDLE_WAIT_MILLIS);
            }
        }
    }

    /**
     * Executes one task under its distributed lock. If the lock is held elsewhere the task is
     * skipped and nothing is released. On an exception that {@link #supportsException} does
     * not accept, the task goes through {@link #failProcess}.
     */
    private void runTask(TaskDefinition task) {
        String lockKey = lockKeyOfStr + task.getValue();
        // Unique owner token, so only the lock taken here can be released here.
        String token = UUID.randomUUID().toString();
        boolean locked = false;
        try {
            Boolean acquired = this.template.opsForValue().setIfAbsent(lockKey, token, TASK_LOCK_SECONDS, TimeUnit.SECONDS);
            locked = acquired != null && acquired;
            if (locked) {
                execute(task);
                // Finished: drop the definition and the consumption record.
                taskOverWithClear(task.getValue());
            }
        } catch (Exception e) {
            if(!supportsException(e)){
                failProcess(task, e);
            }
        } finally {
            if (locked) {
                this.template.execute(RELEASE_LOCK_SCRIPT, Collections.singletonList(lockKey), token);
            }
        }
    }

    /**
     * Pops one task value from the ready list without blocking and records it in the
     * transfer-record set in the same script call.
     *
     * @return the task value, or null when the list is empty
     */
    private String consumeOne() {
        List<String> keys = new ArrayList<>();
        keys.add(queueKeyOfList);
        keys.add(transferRecordsKeyOfSet);
        return this.template.execute(CONSUME_ONE_SCRIPT, keys);
    }

    /**
     * Removes a task value from the task hash and from the transfer-record set.
     *
     * @param value task value
     */
    private void taskOverWithClear(String value) {
        List<String> keys = new ArrayList<>();
        keys.add(allValueKeyOfHash);
        keys.add(transferRecordsKeyOfSet);
        this.template.execute(CLEAR_TASK_SCRIPT, keys, value);
    }

    /**
     * Handles a failed execution: clears the task, then either re-schedules it with a retry
     * delay that doubles on every attempt (starting at 2000 ms when none is set), or, once
     * the execute count exceeds the retry limit, passes it to the final-failure handlers.
     *
     * @param task the task that failed
     * @param e    the exception raised by the execution
     */
    protected void failProcess(TaskDefinition task, Exception e) {
        // The file has no logger in scope; the stack trace is printed as before.
        e.printStackTrace();
        Integer count = task.getExecuteCount();
        if (count == null) {
            task.setExecuteCount(1);
        }
        taskOverWithClear(task.getValue());
        // Retry limit exceeded: record the error details and hand over to the handlers.
        if (task.getExecuteCount() > this.retryCount) {
            task.setTaskClass(this.getClass().getCanonicalName());
            task.setError(ExceptionUtil.stacktraceToString(e));
            task.setErrorTime(DateUtil.formatDateTime(new Date()));
            retryError(task);
            return;
        }
        // Still within the limit: schedule the next attempt through the delay queue.
        if (task.getRetryDelayTime() <= 0) {
            task.setRetryDelayTime(2000);
        }
        task.setExecuteTime(System.currentTimeMillis()+task.getRetryDelayTime());
        task.setRetryDelayTime(task.getRetryDelayTime() * 2);
        task.setExecuteCount(task.getExecuteCount() + 1);
        createTask(task);
    }

    /**
     * Passes a task that failed after all retries to every handler that supports it.
     */
    private void retryError(TaskDefinition task) {
        if(CollectionUtils.isEmpty(this.handlers)){
            return ;
        }
        for (DelayTaskRetryErrorHandler handler : this.handlers) {
            if(handler.supports(task)){
                handler.process(task);
            }
        }
    }
}
